from __future__ import print_function
import torch
import torch.nn as nn
import torchvision.utils as utils
from torchvision.utils import save_image
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import numpy as np
import os
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
# Fix the global torch RNG seed so runs are reproducible.
manualSeed = 999
torch.manual_seed(manualSeed)

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
dataroot = "/celeba-dataset/"  # root directory handed to torchvision's ImageFolder
workers = 2          # DataLoader worker processes
batch_size = 64
image_size = 64      # images are resized and centre-cropped to image_size x image_size
nc = 3               # number of image channels
nz = 100             # length of the latent vector
ngf = 64             # base feature-map width of the generator
ndf = 64             # base feature-map width of the discriminator
num_epochs = 5
lr = 0.0002
beta1 = 0.5          # Adam beta1
# Number of GPUs to use; 0 forces CPU. The device selection below reads this,
# so it has to be defined before that line.
ngpu = 1

# Normalize((0.5,)*3, (0.5,)*3) maps ToTensor's [0, 1] range to [-1, 1],
# the same range as the generator's final Tanh.
dataset = dset.ImageFolder(root=dataroot,
                           transform=transforms.Compose([
                               transforms.Resize(image_size),
                               transforms.CenterCrop(image_size),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=True, num_workers=workers)

device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")

# Plot some training images
real_batch = next(iter(dataloader))
plt.figure(figsize=(8, 8))
plt.axis("off")
plt.title("Training Images")
# make_grid returns C x H x W; imshow expects H x W x C. The grid is built on
# the CPU directly: sending the batch to the GPU only to copy it back is wasted work.
plt.imshow(np.transpose(utils.make_grid(real_batch[0][:64], padding=2, normalize=True), (1, 2, 0)))
# Weight Initialization
def weights_init(m):
    """Re-initialise a module's parameters with the DCGAN scheme.

    Meant to be passed to ``nn.Module.apply``, which calls it on every
    sub-module:

    * classes whose name contains ``Conv`` (so ``Conv2d`` and
      ``ConvTranspose2d``): weight ~ N(0, 0.02);
    * classes whose name contains ``BatchNorm``: weight ~ N(1, 0.02), bias = 0.

    Every other module is left untouched. Parameters that are ``None`` (e.g.
    ``BatchNorm2d(affine=False)``) are skipped instead of raising.

    Args:
        m: the module to initialise in place.
    """
    classname = m.__class__.__name__
    if 'Conv' in classname:
        if getattr(m, 'weight', None) is not None:
            nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif 'BatchNorm' in classname:
        if getattr(m, 'weight', None) is not None:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
        if getattr(m, 'bias', None) is not None:
            nn.init.constant_(m.bias.data, 0)
# Generator
class Generator(nn.Module):
    """DCGAN generator network.

    Upsamples a latent tensor through five transposed convolutions into an
    ``nc``-channel image whose values are squashed by Tanh. Reads the
    module-level ``nz``, ``ngf`` and ``nc`` when constructed.
    """

    @staticmethod
    def _up(in_ch, out_ch, stride, padding):
        """Return one ConvTranspose2d(kernel 4, no bias) -> BatchNorm2d -> ReLU stage."""
        return [
            nn.ConvTranspose2d(in_ch, out_ch, 4, stride, padding, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(True),
        ]

    def __init__(self):
        super().__init__()
        # nz x 1 x 1  ->  (ngf*8) x 4 x 4
        stages = self._up(nz, ngf * 8, 1, 0)
        # (ngf*8) x 4 x 4 -> (ngf*4) x 8 x 8 -> (ngf*2) x 16 x 16 -> ngf x 32 x 32
        for src_mult, dst_mult in ((8, 4), (4, 2), (2, 1)):
            stages += self._up(ngf * src_mult, ngf * dst_mult, 2, 1)
        # ngf x 32 x 32  ->  nc x 64 x 64
        stages += [nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False), nn.Tanh()]
        # Layers are created in the same order as a flat listing, so the
        # Sequential indices (and state_dict keys main.0 .. main.13) are unchanged.
        self.main = nn.Sequential(*stages)

    def forward(self, inputs):
        """Map latent noise (presumably N x nz x 1 x 1) to generated images."""
        return self.main(inputs)
# Build the generator and apply the DCGAN initialisation; nn.Module.apply
# visits every sub-module and returns the module itself, so the calls chain.
generator = Generator().apply(weights_init)
print(generator)
# Discriminator
class Discriminator(nn.Module):
    """DCGAN discriminator network.

    Downsamples an ``nc``-channel image with strided convolutions and ends in
    a Sigmoid, producing one score per image for use with BCELoss. Reads the
    module-level ``nc`` and ``ndf`` when constructed.
    """

    @staticmethod
    def _down(in_ch, out_ch, normalize):
        """Return Conv2d(kernel 4, stride 2, pad 1, no bias) [-> BatchNorm2d] -> LeakyReLU(0.2)."""
        stage = [nn.Conv2d(in_ch, out_ch, 4, 2, 1, bias=False)]
        if normalize:
            stage.append(nn.BatchNorm2d(out_ch))
        stage.append(nn.LeakyReLU(0.2, inplace=True))
        return stage

    def __init__(self):
        super().__init__()
        # The first stage has no BatchNorm; the next three do.
        stages = self._down(nc, ndf, normalize=False)
        for src_mult, dst_mult in ((1, 2), (2, 4), (4, 8)):
            stages += self._down(ndf * src_mult, ndf * dst_mult, normalize=True)
        # For 64x64 inputs the feature map is 4x4 here; a 4x4 valid convolution
        # reduces it to a single value per image.
        stages += [nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False), nn.Sigmoid()]
        # Same creation order as a flat listing -> same state_dict keys (main.0 .. main.12).
        self.main = nn.Sequential(*stages)

    def forward(self, inputs):
        """Score a batch of images; output is N x 1 x 1 x 1 for 64x64 inputs."""
        return self.main(inputs)
# Build the discriminator and apply the DCGAN initialisation (apply returns
# the module, so construction and initialisation chain).
discriminator = Discriminator().apply(weights_init)
print(discriminator)
# Move both networks to the selected device *before* creating the optimizers
# (PyTorch recommends this ordering), so the optimizers track on-device parameters.
generator.to(device)
discriminator.to(device)

Criterion = nn.BCELoss()
# Fixed latent batch used to visualise generator progress; it must live on
# the same device as the generator.
fixed_noise = torch.randn(64, nz, 1, 1, device=device)
# BCELoss needs floating-point targets, so the label values are floats.
real_label = 1.0
fake_label = 0.0
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999))

# Lists to keep track of progress
img_list = []
G_losses = []
D_losses = []
iters = 0

print("Starting Training Loop.....")
for epoch in range(num_epochs):
    for i, data in enumerate(dataloader, 0):
        # ---- (1) Discriminator step: real batch labelled real, fake batch labelled fake ----
        discriminator.zero_grad()
        real = data[0].to(device)
        b_size = real.size(0)
        # Explicit float dtype: with an int fill value torch.full would infer
        # int64, which BCELoss rejects.
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
        output = discriminator(real).view(-1)
        error_d_real = Criterion(output, label)
        error_d_real.backward()
        D_x = output.mean().item()

        noise = torch.randn(b_size, nz, 1, 1, device=device)
        fake = generator(noise)
        label.fill_(fake_label)
        # detach(): the discriminator step must not backpropagate into the generator.
        output = discriminator(fake.detach()).view(-1)
        error_d_fake = Criterion(output, label)
        error_d_fake.backward()
        d_g_z1 = output.mean().item()
        err_d = error_d_fake + error_d_real
        optimizer_D.step()

        # ---- (2) Generator step: fakes are labelled "real" so G is pushed to fool D ----
        generator.zero_grad()
        label.fill_(real_label)
        output = discriminator(fake).view(-1)
        error_g = Criterion(output, label)
        error_g.backward()
        d_g_z2 = output.mean().item()
        optimizer_G.step()

        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, len(dataloader),
                     err_d.item(), error_g.item(), D_x, d_g_z1, d_g_z2))

        # Save Losses for plotting later
        G_losses.append(error_g.item())
        D_losses.append(err_d.item())

        # Check how the generator is doing by saving G's output on fixed_noise
        if (iters % 500 == 0) or ((epoch == num_epochs - 1) and (i == len(dataloader) - 1)):
            with torch.no_grad():
                fake = generator(fixed_noise).detach().cpu()
            img_list.append(utils.make_grid(fake, padding=2, normalize=True))
        iters += 1

# Persist the trained weights once training has finished.
torch.save(generator.state_dict(), 'generator_dcgan.pt')
torch.save(discriminator.state_dict(), 'discriminator_dcgan.pt')
# Loss curves for both networks, one point per training iteration.
plt.figure(figsize=(10,5))
plt.title("Generator and Discriminator Loss During Training")
for history, tag in ((G_losses, "G"), (D_losses, "D")):
    plt.plot(history, label=tag)
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()

#%%capture
# Animate the fixed-noise snapshots collected during training.
fig = plt.figure(figsize=(8,8))
plt.axis("off")
frames = []
for grid in img_list:
    # make_grid yields C x H x W; imshow wants H x W x C.
    frames.append([plt.imshow(np.transpose(grid, (1, 2, 0)), animated=True)])
ani = animation.ArtistAnimation(fig, frames, interval=1000, repeat_delay=1000, blit=True)
# The HTML object is only rendered when this is the last expression of a
# notebook cell; as a plain script the result is discarded.
HTML(ani.to_jshtml())